#!/usr/bin/env python
# coding: utf-8
#copyRight by heibanke
# Please credit the source when reposting.
# Course: "Deep Learning with Python, Part 2: caffe" (<<用Python做深度学习2-caffe>>)
#http://study.163.com/course/courseMain.htm?courseId=1003491001

import caffe
import matplotlib.pyplot as plt
from pylab import *
from caffe import layers as L
from caffe import params as P
import caffe.draw
from caffe.proto import caffe_pb2
from google.protobuf import text_format
# Root directory of the demo data. The trailing slash is required because the
# full paths below are produced by plain string concatenation.
data_path = "/Users/rogerluo/Desktop/pyopnecv/caffedemo/verifycode/"

# Bare file names of the network and solver definitions.
train_net_file, test_net_file = 'lenet_train.prototxt', 'lenet_test.prototxt'
solver_file = "lenet_solver.prototxt"

# Full paths derived from data_path.
net_file = "".join((data_path, train_net_file))
predictpath = "".join((data_path, "train/1-0-7.jpg"))

# def net(datafile, mean_file, batch_size):
#     n=caffe.NetSpec()
#     n.data,n.label =L.Data(source=datafile,backend = P.Data.LMDB, batch_size=batch_size, ntop=2, transform_param=dict(scale=1.0/255.0, mean_file=mean_file))
#     n.conv1 = L.Convolution(n.data,num_output=20,kernel_size=5,stride=1,weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'),param=[{"lr_mult":1},{"lr_mult": 2}])
#     n.pool1 = L.Pooling(n.conv1,pool=P.Pooling.MAX,kernel_size=2,stride=2)
#     #n.relu1 = L.ReLU(n.conv1,in_place=True)
#     n.conv2 = L.Convolution(n.pool1,num_output=50,kernel_size=5,stride=1,weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'),param=[{"lr_mult":1},{"lr_mult": 2}])
#     n.pool2 = L.Pooling(n.conv2,pool=P.Pooling.MAX,kernel_size=2,stride=2)
#     n.ip1 = L.InnerProduct(n.pool2, num_output=500, weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'),param=[{"lr_mult":1},{"lr_mult": 2}])
#     n.relu1 = L.ReLU(n.ip1,in_place=True)
#     n.ip2 = L.InnerProduct(n.ip1,num_output=56, weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'),param=[{"lr_mult":1},{"lr_mult": 2}])
#     n.accu = L.Accuracy(n.ip2,n.label,include={'phase':caffe.TEST})
#     n.loss = L.SoftmaxWithLoss(n.ip2, n.label)
#     return n.to_proto()
#
# ### net file generate #####
#
# with open( train_net_file, 'w') as f:
#     f.write(str(net(data_path+'train_lmdb',  data_path+'mean.binaryproto', 64)))
# with open( test_net_file, 'w') as f:
#     f.write(str(net(data_path+'test_lmdb',  data_path+'mean.binaryproto', 40)))


# solver file generate ######
# from caffe.proto import caffe_pb2
# s = caffe_pb2.SolverParameter()
#
# s.train_net = train_net_file
# s.test_net.append(test_net_file)
# s.test_interval = 500
# s.test_iter.append(10)
# s.display = 100
# s.max_iter = 10000
# s.weight_decay = 0.005
# s.base_lr = 0.01
# s.momentum = 0.9
# s.snapshot=5000
# s.snapshot_prefix="verify_code"
# s.lr_policy = "inv"
# s.gamma = 0.0001
# s.solver_mode = caffe_pb2.SolverParameter.CPU
#
# with open(solver_file, 'w') as f:
#     f.write(str(s))

#
# ### Run the solver and record train loss / test accuracy per iteration #####
solver = caffe.get_solver(solver_file)


niter = 101
train_loss = zeros(niter)   # training loss after each solver step
test_acc = zeros(niter)     # accuracy of one test batch after each solver step

# Buffer that is only referenced by the commented-out visualisation snippet
# further down; nothing in this loop writes to it.
output = zeros((niter, 8, 10))
# The main solver loop
for it in range(niter):
    solver.step(1)  # one SGD iteration, performed by Caffe
    train_loss[it] = solver.net.blobs['loss'].data
    # solver.step() only evaluates the test net at the solver's test_interval,
    # so without an explicit forward pass the 'accu' blob would hold a stale
    # value. Run one test batch so the recorded accuracy reflects the current
    # weights.
    solver.test_nets[0].forward()
    test_acc[it] = solver.test_nets[0].blobs['accu'].data



#-------------------------predict-----------------------------
# plt.subplot(221)
# print solver.net.blobs['data'].data[:8, 0].transpose(1, 0, 2).reshape(26, 8*22)

# Show (up to) the first 8 images of the current training batch, channel 0,
# tiled side by side in one row. The tile size is read from the blob itself
# instead of being hard-coded, so other image sizes or smaller batches work.
preview_tiles = solver.net.blobs['data'].data[:8, 0]
tile_count, tile_height, tile_width = preview_tiles.shape
# (n, h, w) -> (h, n, w) -> (h, n*w): each output row is the concatenation of
# the corresponding row of every tile.
plt.imshow(
    preview_tiles.transpose(1, 0, 2).reshape(tile_height, tile_count * tile_width),
    cmap='gray')
axis('off')
# Single-argument print: same text under the Python 2 print statement and the
# Python 3 print function.
print('train labels: %s' % (solver.net.blobs['label'].data[:8],))

# for i in range(8):
#     figure(figsize=(2, 2))
#     imshow(solver.test_nets[0].blobs['data'].data[i, 0], cmap='gray')
#     figure(figsize=(10, 2))
#     imshow(output[:65, i].T, interpolation='nearest', cmap='gray')
#     xlabel('iteration')
#     ylabel('label')
plt.show()
# image = caffe.io.load_image(predictpath)
# transformed_image = solver.transformer.preprocess('data', image)
# plt.imshow(image)
# solver.net.blobs['data'].data[...] = transformed_image
# output = solver.net.forward()
# output_prob = output['prob'][0]
# print 'predicted class is:', output_prob.argmax()

def print_net_shape(net):
    """Print the shapes of all blobs and learnable parameters of *net*.

    For every entry of ``net.blobs`` the shape of its ``data`` and ``diff``
    arrays is printed. For every entry of ``net.params`` the ``data`` and
    ``diff`` shapes of each parameter blob are printed on one line, separated
    by a single space. Layers may carry any number of parameter blobs (for
    example a layer without a bias has only one), so nothing is indexed by a
    fixed position.

    Args:
        net: object exposing ``blobs`` (mapping name -> blob) and ``params``
            (mapping name -> sequence of blobs); each blob must provide
            ``data`` and ``diff`` attributes that have a ``shape``.
    """
    # print(<single string>) behaves the same under Python 2 and Python 3.
    print("======data and diff output shape======")
    for layer_name, blob in net.blobs.items():
        print(layer_name + ' out \t' + str(blob.data.shape))
        print(layer_name + ' diff\t' + str(blob.diff.shape))

    print("======   weight and bias shape  ======")
    for layer_name, param in net.params.items():
        data_shapes = ' '.join(str(p.data.shape) for p in param)
        diff_shapes = ' '.join(str(p.diff.shape) for p in param)
        print(layer_name + ' weight\t' + data_shapes)
        print(layer_name + ' diff  \t' + diff_shapes)

def draw_net(net_file, jpg_file):
    """Render the network described by a prototxt file to an image file.

    The text-format network definition is parsed into a ``NetParameter``
    message and handed to ``caffe.draw.draw_net_to_file``.

    Args:
        net_file: path of the text-format (prototxt) network definition.
        jpg_file: path of the image file to write.
    """
    net = caffe_pb2.NetParameter()
    # Use a context manager so the file handle is closed deterministically.
    with open(net_file) as proto_file:
        text_format.Merge(proto_file.read(), net)
    # 'BT' is passed as the layout-direction argument (bottom-to-top).
    caffe.draw.draw_net_to_file(net, jpg_file, 'BT')








# print_net_shape(solver.net)
# draw_net(net_file, "a.jpg")
# #output graph
# _, ax1 = subplots()
# ax2 = ax1.twinx()
# ax1.plot(arange(niter), train_loss)
# ax2.plot(arange(niter), test_acc, 'r')
# ax1.set_xlabel('iteration')
# ax1.set_ylabel('train loss')
# ax2.set_ylabel('test accuracy')
# _.savefig('converge01.png')
